#
# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#  http://aws.amazon.com/apache2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
#

"""
Handshake tests using Openssl s_client against s2nd
Openssl 1.1.0 removed SSLv3, 3DES, and RC4, so we won't have coverage there.
"""

import argparse
import os
import sys
import subprocess
import itertools
import multiprocessing
from os import environ
from multiprocessing.pool import ThreadPool
from s2n_test_constants import *

# Maps each s2n protocol version constant to the s_client flag that selects it.
PROTO_VERS_TO_S_CLIENT_ARG = dict([
    (S2N_TLS10, "-tls1"),
    (S2N_TLS11, "-tls1_1"),
    (S2N_TLS12, "-tls1_2"),
])

# Substring searched for in s_client's output when checking the stapled OCSP response.
S_CLIENT_SUCCESSFUL_OCSP = "OCSP Response Status: successful"

def cleanup_processes(*processes):
    """
    Forcefully stop every given subprocess and release its resources.

    Each process is killed and then waited on so that it is reaped. Any pipes
    that were opened for it are then closed, so their file descriptors are
    released immediately instead of whenever the Popen object happens to be
    garbage collected.

    :param processes: subprocess.Popen objects to terminate
    """
    for p in processes:
        p.kill()
        p.wait()
        for pipe in (p.stdin, p.stdout, p.stderr):
            if pipe is None:
                continue
            try:
                pipe.close()
            except OSError:
                # Closing a writable pipe flushes buffered data first. The
                # reader has just been killed, so a broken pipe is expected.
                pass

def try_handshake(endpoint, port, cipher, ssl_version, server_cert=None, server_key=None, ocsp=None, sig_algs=None, curves=None, resume=False,
        prefer_low_latency=False, enter_fips_mode=False, client_auth=None, client_cert=DEFAULT_CLIENT_CERT_PATH, client_key=DEFAULT_CLIENT_KEY_PATH):
    """
    Attempt to handshake against s2nd listening on `endpoint` and `port` using Openssl s_client

    Both processes are always killed and reaped before this function returns,
    including when an exception is raised part-way through.

    :param str endpoint: endpoint for s2nd to listen on
    :param int port: port for s2nd to listen on
    :param str cipher: ciphers for Openssl s_client to offer. See https://www.openssl.org/docs/man1.0.2/apps/ciphers.html
    :param int ssl_version: SSL version for s_client to use. Must be a key of PROTO_VERS_TO_S_CLIENT_ARG.
    :param str server_cert: path to certificate for s2nd to use
    :param str server_key: path to private key for s2nd to use
    :param str ocsp: path to OCSP response file for stapling
    :param str sig_algs: Signature algorithms for s_client to offer
    :param str curves: Elliptic curves for s_client to offer
    :param bool resume: True if s_client should try to reconnect to s2nd and reuse the same TLS session. False for normal negotiation.
    :param bool prefer_low_latency: True if s2nd should use 1500 for max outgoing record size. False for default max.
    :param bool enter_fips_mode: True if s2nd should enter libcrypto's FIPS mode. Libcrypto must be built with a FIPS module to enter FIPS mode.
    :param client_auth: Client authentication is used whenever this is not None (callers in this file pass True or leave it unset)
    :param str client_cert: Path to the client's cert file
    :param str client_key: Path to the client's private key file
    :return: 0 on successful negotiation(s), -1 on failure
    """

    # Override certificate for ECDSA if unspecified. We can remove this when we
    # support multiple certificates
    if server_cert is None and "ECDSA" in cipher:
        server_cert = TEST_ECDSA_CERT
        server_key = TEST_ECDSA_KEY

    # Assemble the s2nd command line
    s2nd_cmd = ["../../bin/s2nd"]

    if server_cert is not None:
        s2nd_cmd.extend(["--cert", server_cert])
    if server_key is not None:
        s2nd_cmd.extend(["--key", server_key])
    if ocsp is not None:
        s2nd_cmd.extend(["--ocsp", ocsp])
    if prefer_low_latency == True:
        s2nd_cmd.append("--prefer-low-latency")
    if client_auth is not None:
        s2nd_cmd.append("-m")
        s2nd_cmd.extend(["-t", client_cert])

    s2nd_cmd.extend([str(endpoint), str(port)])

    # The ECDSA check comes last, so it takes precedence over the FIPS cipher
    # preferences when both apply.
    s2nd_ciphers = "test_all"
    if enter_fips_mode == True:
        s2nd_ciphers = "test_all_fips"
        s2nd_cmd.append("--enter-fips-mode")
    if "ECDSA" in cipher:
        s2nd_ciphers = "test_all_ecdsa"
    s2nd_cmd.append("-c")
    s2nd_cmd.append(s2nd_ciphers)

    # Assemble the s_client command line before starting anything, so that an
    # unknown ssl_version fails without a server process having been spawned.
    s_client_cmd = ["openssl", "s_client", PROTO_VERS_TO_S_CLIENT_ARG[ssl_version],
            "-connect", str(endpoint) + ":" + str(port)]
    if cipher is not None:
        s_client_cmd.extend(["-cipher", cipher])
    if sig_algs is not None:
        s_client_cmd.extend(["-sigalgs", sig_algs])
    if curves is not None:
        s_client_cmd.extend(["-curves", curves])
    if resume == True:
        s_client_cmd.append("-reconnect")
    if client_auth is not None:
        s_client_cmd.extend(["-key", client_key])
        s_client_cmd.extend(["-cert", client_cert])
    if ocsp is not None:
        s_client_cmd.append("-status")

    # Every process that was actually started, in start order, so the finally
    # clause can tear down exactly those on any exit path.
    started = []
    try:
        # Fire up s2nd
        s2nd = subprocess.Popen(s2nd_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        started.append(s2nd)

        # Make sure it's running: block until s2nd writes its first line.
        # readline() only returns an empty value at EOF, which means s2nd
        # exited without coming up.
        if not s2nd.stdout.readline():
            return -1

        # Fire up s_client
        s_client = subprocess.Popen(s_client_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
        started.append(s_client)

        # Validate that s_client resumes successfully against s2nd
        if resume is True:
            resumed_sessions = 0
            for line in s2nd.stdout:
                line = line.decode("utf-8").strip()
                if line.startswith("Resumed session"):
                    resumed_sessions += 1

                if resumed_sessions == 5:
                    break

            if resumed_sessions != 5:
                return -1

        # Validate that s_client accepted s2nd's stapled OCSP response
        if ocsp is not None:
            ocsp_success = False
            for line in s_client.stdout:
                line = line.decode("utf-8").strip()
                if S_CLIENT_SUCCESSFUL_OCSP in line:
                    ocsp_success = True
                    break
            if not ocsp_success:
                return -1

        # Write the cipher name towards s2n
        s_client.stdin.write((cipher + "\n").encode("utf-8"))
        s_client.stdin.flush()

        # Read it back from s2nd, looking at no more than 10 lines of output
        for _ in range(0, 10):
            output = s2nd.stdout.readline().decode("utf-8")
            if output.strip() == cipher:
                break
        else:
            return -1

        # Write the cipher name from s2n
        s2nd.stdin.write((cipher + "\n").encode("utf-8"))
        s2nd.stdin.flush()

        # Read it back from s_client, looking at no more than 512 lines of output
        for _ in range(0, 512):
            output = s_client.stdout.readline().decode("utf-8")
            if output.strip() == cipher:
                break
        else:
            return -1

        return 0
    except BrokenPipeError:
        # The peer we were writing to has already exited, so the handshake
        # cannot have succeeded. Report it as an ordinary failure.
        return -1
    finally:
        cleanup_processes(*started)

def cert_path_to_str(cert_path):
    """
    Turn a certificate path into a short label suitable for test output.

    Takes the file name (everything after the last '/'), keeps its first three
    underscore-separated fields, joins them with '-' and upper-cases the result.
    Example: "./test_certs/rsa_2048_sha256_client_cert.pem" => "RSA-2048-SHA256"
    """
    file_name = cert_path.rsplit('/', 1)[-1]
    leading_fields = file_name.split('_')[:3]
    return '-'.join(leading_fields).upper()

def print_result(result_prefix, return_code):
    """
    Print `result_prefix` followed by PASSED (return_code == 0) or FAILED.

    The verdict is wrapped in ANSI colour codes (green for PASSED, red for
    FAILED) only when stdout is a terminal.
    """
    if return_code == 0:
        plain, coloured = "PASSED", "\033[32;1mPASSED\033[0m"
    else:
        plain, coloured = "FAILED", "\033[31;1mFAILED\033[0m"

    verdict = coloured if sys.stdout.isatty() else plain
    print(result_prefix + verdict)

def create_thread_pool():
    """
    Create a ThreadPool with twice as many workers as reported CPUs and announce its size.

    :return: the newly created multiprocessing.pool.ThreadPool
    """
    # Multiply by 2 since performance improves slightly if CPU has hyperthreading
    worker_count = 2 * multiprocessing.cpu_count()
    print("\tCreating ThreadPool of size: " + str(worker_count))
    return ThreadPool(processes=worker_count)

def run_handshake_test(host, port, ssl_version, cipher, fips_mode, use_client_auth, client_cert_path, client_key_path):
    """
    Run one handshake for `cipher` at `ssl_version` and print its result line.

    The combination is silently skipped (returning 0) when the cipher is not
    flagged as Openssl 1.1.0 compatible or when `ssl_version` is below the
    cipher's minimum TLS version.

    :return: 0 when skipped, otherwise the return value of try_handshake
    """
    openssl_name, min_version = cipher.openssl_name, cipher.min_tls_vers

    # Skip the cipher if openssl can't test it (3DES/RC4 are disabled by default
    # in 1.1.0) or if the protocol version is too old for it.
    if not cipher.openssl_1_1_0_compatible or ssl_version < min_version:
        return 0

    # Label the row with the client cert when one is in use, otherwise with the raw flag value
    if use_client_auth is None or client_cert_path is None:
        cert_label = str(use_client_auth)
    else:
        cert_label = cert_path_to_str(client_cert_path)

    status = try_handshake(host, port, openssl_name, ssl_version, enter_fips_mode=fips_mode, client_auth=use_client_auth, client_cert=client_cert_path, client_key=client_key_path)

    print_result("Cipher: %-28s ClientCert: %-16s Vers: %-8s ... " % (openssl_name, cert_label, S2N_PROTO_VERS_TO_STR[ssl_version]), status)

    return status

def handshake_test(host, port, test_ciphers, fips_mode, use_client_auth=None, use_client_cert=None, use_client_key=None):
    """
    Basic handshake tests using all valid combinations of supported cipher suites and TLS versions.

    For each TLS version, the ciphers are submitted to a fresh thread pool. The
    i-th cipher in `test_ciphers` uses port `port + i`, so concurrent handshakes
    do not share a port.

    :return: 1 if any handshake returned non-zero, otherwise 0
    """
    print("\n\tRunning handshake tests:")

    any_failure = 0
    for ssl_version in (S2N_TLS10, S2N_TLS11, S2N_TLS12):
        print("\n\tTesting ciphers using client version: " + S2N_PROTO_VERS_TO_STR[ssl_version])
        pool = create_thread_pool()

        pending = [
            pool.apply_async(run_handshake_test, (host, port + offset, ssl_version, cipher, fips_mode, use_client_auth, use_client_cert, use_client_key))
            for offset, cipher in enumerate(test_ciphers)
        ]

        # Wait for every submitted handshake to finish before collecting results
        pool.close()
        pool.join()

        outcomes = [job.get() for job in pending]
        if any(code != 0 for code in outcomes):
            any_failure = 1

    return any_failure
    

def client_auth_test(host, port, test_ciphers, fips_mode):
    """
    Run the full handshake test once per RSA client certificate found in TEST_CERT_DIRECTORY.

    A file counts as a client certificate when its name contains both
    "client_cert" and "rsa". The matching key is expected under the same name
    with "client_cert" replaced by "client_key". Skipped entirely in FIPS mode.

    :return: number of client certificates for which handshake_test reported a failure
    """
    failed = 0

    print("\n\tRunning client auth tests:")

    if fips_mode:
        print("\t\033[33;1mSKIPPED\033[0m - Client Auth not supported in FIPS mode")
        return 0

    for filename in os.listdir(TEST_CERT_DIRECTORY):
        if "client_cert" in filename and "rsa" in filename:
            client_cert_path = TEST_CERT_DIRECTORY + filename
            client_key_path = TEST_CERT_DIRECTORY + filename.replace("client_cert", "client_key")
            ret = handshake_test(host, port, test_ciphers, fips_mode, use_client_auth=True, use_client_cert=client_cert_path, use_client_key=client_key_path)
            # Compare by value: identity checks against an int literal are
            # implementation-dependent and raise a SyntaxWarning on Python 3.8+.
            if ret != 0:
                failed += 1

    return failed

def resume_test(host, port, test_ciphers, fips_mode):
    """
    Tests s2n's session resumption capability using all valid combinations of cipher suite and TLS version.

    Handshakes run one after another on the same `port`.

    :return: 1 if any resumption handshake returned non-zero, otherwise 0
    """
    print("\n\tRunning resumption tests:")
    overall = 0
    for ssl_version in (S2N_TLS10, S2N_TLS11, S2N_TLS12):
        print("\n\tTesting ciphers using client version: " + S2N_PROTO_VERS_TO_STR[ssl_version])
        for cipher in test_ciphers:
            name = cipher.openssl_name
            min_version = cipher.min_tls_vers

            # Only test ciphers openssl can exercise (3DES/RC4 are disabled by
            # default in 1.1.0) and that are valid for this protocol version.
            if cipher.openssl_1_1_0_compatible and not (ssl_version < min_version):
                status = try_handshake(host, port, name, ssl_version, resume=True, enter_fips_mode=fips_mode)
                print_result("Cipher: %-30s Vers: %-10s ... " % (name, S2N_PROTO_VERS_TO_STR[ssl_version]), status)
                if status != 0:
                    overall = 1

    return overall

# Signature algorithms the sigalg tests expect to be supported (all RSA), and
# ones they expect to be unsupported, which are mixed into the offer as noise.
supported_sigs = ["RSA+" + digest for digest in ("SHA1", "SHA224", "SHA256", "SHA384", "SHA512")]
unsupported_sigs = ["ECDSA+" + digest for digest in ("SHA256", "SHA512")]

def run_sigalg_test(host, port, cipher, ssl_version, permutation, fips_mode, use_client_auth):
    """
    Run one handshake offering `permutation` of signature algorithms and print its result line.

    The unsupported algorithms are placed in front of the permutation in the
    list handed to s_client.

    :param cipher: cipher object whose `openssl_name` is offered by s_client
    :param ssl_version: protocol version used for the handshake and shown in the result line
    :param permutation: ordered sequence of supported signature algorithm names
    :return: the return value of try_handshake
    """
    # Put some unsupported algs in front to make sure we gracefully skip them
    mixed_sigs = unsupported_sigs + list(permutation)
    mixed_sigs_str = ':'.join(mixed_sigs)
    ret = try_handshake(host, port, cipher.openssl_name, ssl_version, sig_algs=mixed_sigs_str, enter_fips_mode=fips_mode, client_auth=use_client_auth)

    # Trim the RSA part off for brevity. User should know we are only supporting RSA at the moment.
    # Report the version that was actually negotiated with, not a fixed one.
    prefix = "Digests: %-35s ClientAuth: %-6s Vers: %-8s... " % (':'.join([x[4:] for x in permutation]), str(use_client_auth), S2N_PROTO_VERS_TO_STR[ssl_version])
    print_result(prefix, ret)
    return ret

def sigalg_test(host, port, fips_mode, use_client_auth=None):
    """
    Acceptance test for supported signature algorithms. Tests all possible supported sigalgs with unsupported ones mixed in
    for noise.

    Every ordering of every non-empty subset of `supported_sigs` is tried with
    one ECDHE and one DHE cipher suite. Handshakes for a given subset size run
    in a thread pool, each on its own port counting up from `port`.

    :return: 1 if any handshake returned non-zero, otherwise 0 (also 0 when skipped)
    """
    print("\n\tRunning signature algorithm tests:")

    if fips_mode and use_client_auth:
        print("\t\033[33;1mSKIPPED\033[0m - Client Auth not supported in FIPS mode")
        return 0

    print("\tExpected supported:   " + str(supported_sigs))
    print("\tExpected unsupported: " + str(unsupported_sigs))

    # Try an ECDHE cipher suite and a DHE one
    sigalg_cipher_names = ("ECDHE-RSA-AES128-GCM-SHA256", "DHE-RSA-AES128-GCM-SHA256")

    any_failure = 0
    for size in range(1, len(supported_sigs) + 1):
        print("\n\t\tTesting ciphers using signature preferences of size: " + str(size))
        pool = create_thread_pool()
        pending = []

        # Produce permutations of every accepted signature algorithm in every possible order
        for permutation, cipher in itertools.product(itertools.permutations(supported_sigs, size), ALL_TEST_CIPHERS):
            if cipher.openssl_name not in sigalg_cipher_names:
                continue
            # The number of jobs submitted so far doubles as the next free port offset
            job = pool.apply_async(run_sigalg_test, (host, port + len(pending), cipher, S2N_TLS12, permutation, fips_mode, use_client_auth))
            pending.append(job)

        pool.close()
        pool.join()

        outcomes = [job.get() for job in pending]
        if any(code != 0 for code in outcomes):
            any_failure = 1

    return any_failure

def elliptic_curve_test(host, port, fips_mode):
    """
    Acceptance test for supported elliptic curves. Tests all possible supported curves with unsupported curves mixed in
    for noise.

    Every ordering of every non-empty subset of the supported curves is tried
    with two ECDHE-RSA cipher suites, one handshake at a time on `port`.

    :return: 1 if any handshake returned non-zero, otherwise 0
    """
    accepted_curves = ["P-256", "P-384"]
    rejected_curves = ["B-163", "K-409"]
    print("\n\tRunning elliptic curve tests:")
    print("\tExpected supported:   " + str(accepted_curves))
    print("\tExpected unsupported: " + str(rejected_curves))

    ecdhe_cipher_names = ("ECDHE-RSA-AES128-GCM-SHA256", "ECDHE-RSA-AES128-SHA")

    outcome = 0
    for size in range(1, len(accepted_curves) + 1):
        print("\n\t\tTesting ciphers using curve list of size: " + str(size))

        # Produce permutations of every accepted curve in every possible order
        for ordering in itertools.permutations(accepted_curves, size):
            # Put some unsupported curves in front to make sure we gracefully skip them
            offered_curves = ':'.join(rejected_curves + list(ordering))
            for cipher in ALL_TEST_CIPHERS:
                if cipher.openssl_name not in ecdhe_cipher_names:
                    continue
                if fips_mode and cipher.openssl_fips_compatible == False:
                    continue
                status = try_handshake(host, port, cipher.openssl_name, S2N_TLS12, curves=offered_curves, enter_fips_mode=fips_mode)
                label = "Curves: %-40s Vers: %10s ... " % (':'.join(ordering), S2N_PROTO_VERS_TO_STR[S2N_TLS12])
                print_result(label, status)
                if status != 0:
                    outcome = 1
    return outcome

def elliptic_curve_fallback_test(host, port, fips_mode):
    """
    Tests graceful fallback when s2n doesn't support any curves offered by the client. A non-ecc suite should be
    negotiated.

    :return: 1 if the handshake returned non-zero, otherwise 0
    """
    # Offer only curves s2n is expected to reject, alongside one EC and one
    # non-EC kx (AES256-GCM-SHA384) suite, so the non-EC suite must be chosen.
    offered_curves = ":".join(["B-163", "K-409"])
    status = try_handshake(host, port, "ECDHE-RSA-AES128-SHA256:AES256-GCM-SHA384", S2N_TLS12, curves=offered_curves, enter_fips_mode=fips_mode)
    print_result("%-65s ... " % "Testing curve mismatch fallback", status)

    return 1 if status != 0 else 0

def handshake_fragmentation_test(host, port, fips_mode):
    """
    Tests successful negotiation with s_client despite message fragmentation. Max record size is clamped to force s2n
    to fragment the ServerCertificate message.

    :return: 1 if any handshake returned non-zero, otherwise 0
    """
    print("\n\tRunning handshake fragmentation tests:")
    failed = 0
    for ssl_version in [S2N_TLS10, S2N_TLS11, S2N_TLS12]:
        print("\n\tTesting ciphers using client version: " + S2N_PROTO_VERS_TO_STR[ssl_version])
        # Cipher isn't relevant for this test, pick one available in all OpenSSL versions and all TLS versions
        cipher_name = "AES256-SHA"

        # Low latency option indirectly forces fragmentation.
        ret = try_handshake(host, port, cipher_name, ssl_version, prefer_low_latency=True, enter_fips_mode=fips_mode)
        result_prefix = "Cipher: %-30s Vers: %-10s ... " % (cipher_name, S2N_PROTO_VERS_TO_STR[ssl_version])
        print_result(result_prefix, ret)
        if ret != 0:
            failed = 1

    # Propagate the recorded result; it must not be reset here, otherwise a
    # failing fragmentation handshake would never fail the run.
    return failed

def ocsp_stapling_test(host, port, fips_mode):
    """
    Test s2n's server OCSP stapling capability

    Runs one handshake per TLS version with the OCSP test certificate, key and
    response file.

    :return: 1 if any handshake returned non-zero, otherwise 0
    """
    print("\n\tRunning OCSP stapling tests:")

    # Cipher isn't relevant for this test, pick one available in all TLS versions
    cipher_name = "AES256-SHA"

    outcome = 0
    for ssl_version in (S2N_TLS10, S2N_TLS11, S2N_TLS12):
        print("\n\tTesting ciphers using client version: " + S2N_PROTO_VERS_TO_STR[ssl_version])

        status = try_handshake(host, port, cipher_name, ssl_version, enter_fips_mode=fips_mode, server_cert=TEST_OCSP_CERT, server_key=TEST_OCSP_KEY,
                ocsp=TEST_OCSP_RESPONSE_FILE)
        print_result("Cipher: %-30s Vers: %-10s ... " % (cipher_name, S2N_PROTO_VERS_TO_STR[ssl_version]), status)
        if status != 0:
            outcome = 1

    return outcome

def main():
    """
    Parse the command line, then run every test group in a fixed order.

    FIPS mode is enabled when the S2N_TEST_IN_FIPS_MODE environment variable is
    set to any value.

    :return: sum of the values returned by the test groups (0 when all passed)
    """
    arg_parser = argparse.ArgumentParser(description='Runs TLS server integration tests against s2nd using Openssl s_client')
    arg_parser.add_argument('host', help='The host for s2nd to bind to')
    arg_parser.add_argument('port', type=int, help='The port for s2nd to bind to')
    arg_parser.add_argument('--libcrypto', default='openssl-1.1.0', choices=['openssl-1.0.2', 'openssl-1.0.2-fips', 'openssl-1.1.0', 'openssl-1.1.x-master', 'libressl'],
            help="""The Libcrypto that s2n was built with. s2n supports different cipher suites depending on
                    libcrypto version. Defaults to openssl-1.1.0.""")
    cli = arg_parser.parse_args()

    # Retrieve the test ciphers to use based on the libcrypto version s2n was built with
    test_ciphers = S2N_LIBCRYPTO_TO_TEST_CIPHERS[cli.libcrypto]
    host, port = cli.host, cli.port

    fips_mode = environ.get("S2N_TEST_IN_FIPS_MODE") is not None
    if fips_mode:
        print("\nRunning s2nd in FIPS mode.")

    print("\nRunning tests with: " + os.popen('openssl version').read())

    # Test groups, run lazily and strictly in this order
    test_groups = (
        lambda: resume_test(host, port, test_ciphers, fips_mode),
        lambda: handshake_test(host, port, test_ciphers, fips_mode),
        lambda: client_auth_test(host, port, test_ciphers, fips_mode),
        lambda: sigalg_test(host, port, fips_mode),
        lambda: sigalg_test(host, port, fips_mode, use_client_auth=True),
        lambda: elliptic_curve_test(host, port, fips_mode),
        lambda: elliptic_curve_fallback_test(host, port, fips_mode),
        lambda: handshake_fragmentation_test(host, port, fips_mode),
        lambda: ocsp_stapling_test(host, port, fips_mode),
    )
    return sum(run_group() for run_group in test_groups)

# Script entry point: the process exit status is whatever main() returns.
if __name__ == "__main__":
    raise SystemExit(main())

